{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# Analyze Errors and Explore Interpretability of Models\n",
    "\n",
    "This notebook demonstrates how to use the Responsible AI Widget's Error Analysis dashboard to understand a model trained on the multiclass wine dataset. The goal of this sample notebook is to classify types of wine with scikit-learn and explore model errors and explanations:\n",
    "\n",
    "1. Train an SVM classification model using Scikit-learn\n",
    "2. Run Interpret-Community's 'explain_model' globally and locally to generate model explanations.\n",
    "3. Visualize model errors and global and local explanations with the Error Analysis visualization dashboard."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install Required Packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install --upgrade interpret-community\n",
    "# %pip install --upgrade raiwidgets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import required packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import load_wine\n",
    "from sklearn import svm\n",
    "\n",
    "# Imports for SHAP MimicExplainer with LightGBM surrogate model\n",
    "from interpret.ext.blackbox import MimicExplainer\n",
    "from interpret.ext.glassbox import LGBMExplainableModel"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the wine data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wine = load_wine()\n",
    "X = wine['data']\n",
    "y = wine['target']\n",
    "classes = wine['target_names']\n",
    "feature_names = wine['feature_names']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split data into train and test\n",
    "from sklearn.model_selection import train_test_split\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train a SVM classification model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "clf = svm.SVC(gamma=0.001, C=100., probability=True)\n",
    "model = clf.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notice the model makes a fair number of errors\n",
    "print(\"number of errors on test dataset: \" + str(sum(model.predict(X_test) != y_test)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load simple ErrorAnalysis view without explanations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from raiwidgets import ErrorAnalysisDashboard\n",
    "predictions = model.predict(X_test)\n",
    "ErrorAnalysisDashboard(dataset=X_test, true_y=y_test, features=feature_names, pred_y=predictions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train a surrogate model to explain the original blackbox model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from interpret_community.common.constants import ModelTask\n",
    "# Train the LightGBM surrogate model using MimicExplaner\n",
    "model_task = ModelTask.Classification\n",
    "explainer = MimicExplainer(model, X_train, LGBMExplainableModel,\n",
    "                           augment_data=True, max_num_of_augmentations=10,\n",
    "                           features=feature_names, classes=classes, model_task=model_task)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate global explanations\n",
    "Explain overall model predictions (global explanation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Passing in test dataset for evaluation examples - note it must be a representative sample of the original data\n",
    "# X_train can be passed as well, but with more examples explanations will take longer although they may be more accurate\n",
    "global_explanation = explainer.explain_global(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print out a dictionary that holds the sorted feature importance names and values\n",
    "print('global importance rank: {}'.format(global_explanation.get_feature_importance_dict()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Analyze model errors and explanations using Error Analysis dashboard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from raiwidgets import ErrorAnalysisDashboard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "ErrorAnalysisDashboard(global_explanation, model, dataset=X_test, true_y=y_test)"
   ]
  }
 ],
 "metadata": {
  "authors": [
   {
    "name": "mesameki"
   }
  ],
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
